#!/usr/bin/env python2

#GoodFET Zensys SPI flash interface
#by Jean-Michel Picod

import sys
from GoodFETSPI import GoodFETZensys

if len(sys.argv) == 1:
    # No verb given: print the synopsis of every supported verb and quit.
    prog = sys.argv[0]
    synopses = (
        "info",
        "infodata",
        "writeinfodata 0x$infodata [xtal_freq]",
        "lockbits",
        "protect_flash",
        "protect_page0",
        "bootsize $size",
        "dump $foo.rom [0x$start 0x$stop]",
        "erase all [xtal_freq]",
        "erase flash [xtal_freq]",
        "erase page $page_number [xtal_freq]",
        "flash $foo.rom [xtal_freq 0x$start 0x$stop]",
        "verify $foo.rom [0x$start 0x$stop]",
    )
    # The header carries its own trailing newline, leaving a blank line
    # before the list of verbs.
    print("Usage: %s verb [objects]\n" % prog)
    for synopsis in synopses:
        print("%s %s" % (prog, synopsis))
    sys.exit()

# Initialize FET and set baud rate.
# This setup runs for every verb (including read-only ones) before the
# verb is dispatched below.
client = GoodFETZensys()
client.serInit()

# Configure the SPI link, then call programming_enable(); presumably this
# switches the target chip into programming mode -- confirm in GoodFETSPI.
client.SPIsetup()
client.programming_enable()

if sys.argv[1] == "info":
    # Print the chip signature: bytes 0-4 as the JEDEC id, byte 5 as the chip
    # type and byte 6 as the chip revision, then name the chip if recognised.
    signature = client.read_signature()
    print("Jedec: %s" % " ".join("%02x" % b for b in signature[:5]))
    chip_type = signature[5]
    print("Chip type: %02x" % chip_type)
    chip_rev = signature[6]
    print("Chip revision: %02x" % chip_rev)
    if chip_type == 0:
        # Revisions 0-5 are reported as ZW0201, revisions 6-7 as ZW0301.
        if chip_rev in (0, 1, 2, 3, 4, 5):
            print("Identified ZW0201 chip")
        elif chip_rev in (6, 7):
            print("Identified ZW0301 chip")

if sys.argv[1] == "infodata":
    # Show the first four infodata bytes as hex, followed by their ASCII
    # rendering with non-printable bytes replaced by ".".
    data = client.read_infodata()
    print("Infodata")
    head = data[:4]
    hex_part = " ".join(["%02x" % x for x in head])
    # The values are integers (they are formatted with %02x above), so each
    # printable one must go through chr() -- str.join() rejects raw ints.
    ascii_part = "".join([chr(x) if 32 <= x <= 126 else "." for x in head])
    print("%s    %s" % (hex_part, ascii_part))

if sys.argv[1] == "lockbits":
    # Report the lock bits: entry 0 is the flash read protection, entries
    # 1-3 form the 3-bit boot sector size code (entry 3 is the most
    # significant bit) and entry 4 is the page 0 write protection.
    boot_sector_sizes = {
        0: 32768,
        1: 16384,
        2: 8192,
        3: 4096,
        4: 2048,
        5: 1024,
        6: 512,
        7: 0,
    }
    bits = client.read_lock_bits()
    size_code = (bits[3] << 2) + (bits[2] << 1) + bits[1]
    if bits[0] == 0:
        flash_state = "protected"
    else:
        flash_state = "readable"
    print("Flash memory is %s" % flash_state)
    print("Boot sector size: %d" % boot_sector_sizes[size_code])
    if bits[4] == 0:
        print("Page 0 is protected")
    elif size_code == 0:
        print("Page 0 is protected due to Boot sector size")
    else:
        print("Page 0 is writeable")

if sys.argv[1] == "verify":
    # Usage: verify $foo.rom [0x$start 0x$stop]
    # Compare flash addresses start..stop (inclusive) against the file.  The
    # first byte of the file corresponds to address `start`.
    f = sys.argv[2]
    start = 0x0000
    stop = 0x7fff
    if len(sys.argv) > 3:
        start = int(sys.argv[3], 16)
    if len(sys.argv) > 4:
        stop = int(sys.argv[4], 16)

    print("Verifying code from %06x to %06x as %s." % (start, stop, f))

    i = start
    mismatches = 0
    # `with` guarantees the file is closed even if a read from the chip fails.
    with open(f, mode='rb') as rom:
        while i <= stop:
            page = list(client.read_page(i // 256))
            # Drop the bytes that precede `i` inside the page (when start is
            # not page-aligned) and never compare past `stop`.
            chunk = page[i % 256:][:stop - i + 1]
            if not chunk:
                # Without this guard an empty read would loop forever.
                print("No data read at %06x; aborting." % i)
                break
            # bytearray yields integers, matching the values in `chunk`.
            expected = bytearray(rom.read(len(chunk)))
            for offset, (a, j) in enumerate(zip(expected, chunk)):
                if a != j:
                    print("%06x: %02x/%02x" % (i + offset, a, j))
                    mismatches += 1
            if len(expected) < len(chunk):
                # The file ran out before the requested range was covered.
                print("%s is shorter than the requested range; stopped at %06x."
                      % (f, i + len(expected)))
                break
            print("Verified %06x." % i)
            i += len(chunk)

    # Report the total once, after the whole range has been checked.
    if mismatches != 0:
        print("%i bytes don't match." % mismatches)

if sys.argv[1] == "dump":
    # Usage: dump $foo.rom [0x$start 0x$stop]
    # Write flash addresses start..stop (inclusive) to the file; the first
    # byte written is the one at address `start`.
    f = sys.argv[2]
    start = 0x0000
    stop = 0x7fff
    if len(sys.argv) > 3:
        start = int(sys.argv[3], 16)
    if len(sys.argv) > 4:
        stop = int(sys.argv[4], 16)

    print("Dumping code from %06x to %06x as %s." % (start, stop, f))

    i = start
    # `with` guarantees the file is flushed and closed even if a read fails.
    with open(f, mode='wb') as rom:
        while i <= stop:
            page = list(client.read_page(i // 256))
            # Drop the bytes that precede `i` inside the page (when start is
            # not page-aligned) and never write past `stop`.
            chunk = page[i % 256:][:stop - i + 1]
            if not chunk:
                # Without this guard an empty read would loop forever.
                print("No data read at %06x; aborting." % i)
                break
            # Write the whole chunk at once instead of one byte at a time.
            rom.write(bytearray(chunk))
            print("Dumped %06x -> %06x." % (i, i + len(chunk)))
            i += len(chunk)

if sys.argv[1] == "writeinfodata":
    # Usage: writeinfodata 0x$infodata [xtal_freq]
    # The value is truncated to 32 bits and written as four bytes, most
    # significant byte first.
    word = int(sys.argv[2], 16) & 0xffffffff
    if len(sys.argv) == 4:
        mhz = int(sys.argv[3])
        print("Setting oscillator to %dMHz" % mhz)
        client.set_write_cycle_time(mhz)
    else:
        print("Defaulting to 32MHz oscillator")
        client.set_write_cycle_time()
    print("Writing infodata")
    client.write_infodata([(word >> shift) & 0xff for shift in (24, 16, 8, 0)])

# The two protection verbs are mutually exclusive, so they share one chain.
verb = sys.argv[1]
if verb == "protect_flash":
    print("Preventing reading back flash memory")
    client.protect_flash_memory()
elif verb == "protect_page0":
    print("Setting Page0 to read-only")
    client.protect_page0()

if sys.argv[1] == "bootsize":
    # Usage: bootsize $size
    # The size is read as hexadecimal when it carries a 0x/0X prefix and as
    # decimal otherwise.
    size_arg = sys.argv[2]
    radix = 16 if size_arg[:2].lower() == "0x" else 10
    bsize = int(size_arg, radix)
    print("Setting bootsector size to %d bytes" % bsize)
    client.set_bootsector_size(bsize)

if sys.argv[1] == "flash":
    # Usage: flash $foo.rom [xtal_freq 0x$start 0x$stop]
    # Write the file to flash one 256-byte page at a time.  The file is
    # indexed by absolute address: the byte at file offset N goes to
    # address N.
    f = sys.argv[2]
    start = 0x0000
    stop = 0x7fff

    # Parse every argument before touching the hardware.  argv[3] is
    # xtal_freq, so the address range follows it in argv[4] and argv[5].
    mhz = None
    if len(sys.argv) > 3:
        mhz = int(sys.argv[3])
    if len(sys.argv) > 4:
        start = int(sys.argv[4], 16)
    if len(sys.argv) > 5:
        stop = int(sys.argv[5], 16)

    if mhz is not None:
        print("Setting oscillator to %dMHz" % mhz)
        client.set_write_cycle_time(mhz)
    else:
        print("Defaulting to 32MHz oscillator")
        client.set_write_cycle_time()

    print("Flashing code from %06x to %06x with %s." % (start, stop, f))
    # bytearray yields integers directly; `with` closes the file on error.
    with open(f, mode='rb') as rom:
        image = bytearray(rom.read())

    chunksize = 0x100
    # Pages are written whole, so begin at the base of the page that
    # contains `start`; otherwise the page would receive shifted data.
    i = start - start % chunksize

    while i <= stop:
        page_data = list(image[i:i + chunksize])
        if len(page_data) < chunksize:
            # The file ran out: stop cleanly instead of raising IndexError.
            print("%s holds %d bytes; no complete page at %06x, stopping."
                  % (f, len(image), i))
            break
        client.write_page(i // chunksize, page_data)

        i += chunksize
        if i % 0x1000 == 0:
            print("Flashed %06x." % i)

if sys.argv[1] == "erase":
    # Usage: erase all [xtal_freq]
    #        erase flash [xtal_freq]
    #        erase page $page_number [xtal_freq]
    argc = len(sys.argv)
    target = sys.argv[2] if argc > 2 else None

    # Index in argv of the optional xtal_freq argument for each target.
    if target in ("all", "flash"):
        freq_index = 3
    elif target == "page":
        freq_index = 4
    else:
        freq_index = None

    # Accept exactly the documented forms: every mandatory argument present
    # and at most one trailing xtal_freq.
    if freq_index is None or argc not in (freq_index, freq_index + 1):
        print("%s erase all [xtal_freq]" % sys.argv[0])
        print("%s erase flash [xtal_freq]" % sys.argv[0])
        print("%s erase page $page_number [xtal_freq]" % sys.argv[0])
        sys.exit()

    # Parse all numeric arguments before touching the hardware.
    page = None
    if target == "page":
        page = int(sys.argv[3]) & 0x7f
    mhz = None
    if argc > freq_index:
        mhz = int(sys.argv[freq_index])

    # Set the write cycle time exactly once, honouring the user's value.
    if mhz is not None:
        print("Setting oscillator to %dMHz" % mhz)
        client.set_write_cycle_time(mhz)
    else:
        print("Defaulting to 32MHz oscillator")
        client.set_write_cycle_time()

    if target == "all":
        print("Erasing both flash and info")
        client.chip_erase()
    elif target == "flash":
        print("Erasing flash memory")
        client.program_memory_erase()
    else:
        print("Erasing page %d (0x%04x to 0x%04x)" % (page, 256 * page, 256 * (page + 1)))
        client.page_erase(page)

